#Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
#Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
#
#THE BSD LICENSE
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions
#are met:
#
#1. Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
#2. Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
#
#THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
#IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
#OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
#IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
#INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
#NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
#DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
#THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
#THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#from ctypes import *
#from ctypes.util import find_library
from numpy import (float32, float64, uint8, int32, require)
#import ctypes
#import numpy as np
from ctypes import (Structure, c_char_p, c_int, c_float, c_uint, c_long,
                    c_void_p, cdll, POINTER)
from numpy.ctypeslib import ndpointer
import os
import sys

STRING = c_char_p


class CustomStructure(Structure):
    """
        This class extends the functionality of the ctype's structure
        class by adding custom default values to the fields and a way of translating
        field types.
    """
    _defaults_ = {}
    _translation_ = {}

    def __init__(self):
        Structure.__init__(self)
        self.__field_names = [ f for (f, t) in self._fields_]
        self.update(self._defaults_)

    def update(self, dict):
        for k, v in dict.items():
            if k in self.__field_names:
                setattr(self, k, self.__translate(k, v))
            else:
                raise KeyError('No such member: ' + k)

    def __getitem__(self, k):
        if k in self.__field_names:
            return self.__translate_back(k, getattr(self, k))

    def __setitem__(self, k, v):
        if k in self.__field_names:
            setattr(self, k, self.__translate(k, v))
        else:
            raise KeyError('No such member: ' + k)

    def keys(self):
        return self.__field_names

    def __translate(self, k, v):
        if k in self._translation_:
            if v in self._translation_[k]:
                return self._translation_[k][v]
        return v

    def __translate_back(self, k, v):
        if k in self._translation_:
            for tk, tv in self._translation_[k].items():
                if tv == v:
                    return tk
        return v


class FLANNParameters(CustomStructure):
    _fields_ = [
        ('algorithm', c_int),
        ('checks', c_int),
        ('eps', c_float),
        ('sorted', c_int),
        ('max_neighbors', c_int),
        ('cores', c_int),
        ('trees', c_int),
        ('leaf_max_size', c_int),
        ('branching', c_int),
        ('iterations', c_int),
        ('centers_init', c_int),
        ('cb_index', c_float),
        ('target_precision', c_float),
        ('build_weight', c_float),
        ('memory_weight', c_float),
        ('sample_fraction', c_float),
        ('table_number_', c_uint),
        ('key_size_', c_uint),
        ('multi_probe_level_', c_uint),
        ('log_level', c_int),
        ('random_seed', c_long),
    ]
    _defaults_ = {
        'algorithm' : 'kdtree',
        'checks' : 32,
        'eps' : 0.0,
        'sorted' : 1,
        'max_neighbors' : -1,
        'cores' : 0,
        'trees' : 1,
        'leaf_max_size' : 4,
        'branching' : 32,
        'iterations' : 5,
        'centers_init' : 'random',
        'cb_index' : 0.5,
        'target_precision' : 0.9,
        'build_weight' : 0.01,
        'memory_weight' : 0.0,
        'sample_fraction' : 0.1,
        'table_number_': 12,
        'key_size_': 20,
        'multi_probe_level_': 2,
        'log_level' : 'warning',
        'random_seed' : -1
    }
    _translation_ = {
        'algorithm'     : {'linear'    : 0, 'kdtree'    : 1, 'kmeans'    : 2, 'composite' : 3, 'kdtree_single' : 4, 'hierarchical': 5, 'lsh': 6, 'saved': 254, 'autotuned' : 255, 'default'   : 1},
        'centers_init'  : {'random'    : 0, 'gonzales'  : 1, 'kmeanspp'  : 2, 'default'   : 0},
        'log_level'     : {'none'      : 0, 'fatal'     : 1, 'error'     : 2, 'warning'   : 3, 'info'      : 4, 'default'   : 2}
    }


default_flags = ['C_CONTIGUOUS', 'ALIGNED']
allowed_types = [ float32, float64, uint8, int32]

FLANN_INDEX = c_void_p


def load_flann_library():

    root_dir = os.path.abspath(os.path.dirname(__file__))

    libnames = ['libflann.so']
    libdir = 'lib'
    if sys.platform == 'win32':
        libnames = ['flann.dll', 'libflann.dll']
    elif sys.platform == 'darwin':
        libnames = ['libflann.dylib']

    while root_dir is not None:
        for libname in libnames:
            try:
                #print 'Trying ',os.path.join(root_dir,'lib',libname)
                flannlib = cdll[os.path.join(root_dir, libdir, libname)]
                return flannlib
            except Exception:
                pass
            try:
                flannlib = cdll[os.path.join(root_dir, 'build', libdir, libname)]
                return flannlib
            except Exception:
                pass
        tmp = os.path.dirname(root_dir)
        if tmp == root_dir:
            root_dir = None
        else:
            root_dir = tmp

    # if we didn't find the library so far, try loading without
    # a full path as a last resort
    for libname in libnames:
        try:
            #print 'Trying',libname
            flannlib = cdll[libname]
            return flannlib
        except:
            pass

    return None

flannlib = load_flann_library()
if flannlib is None:
    raise ImportError('Cannot load dynamic library. Did you compile FLANN?')


class FlannLib(object):
    pass

flann = FlannLib()


flannlib.flann_log_verbosity.restype = None
flannlib.flann_log_verbosity.argtypes = [
    c_int  # level
]


flannlib.flann_set_distance_type.restype = None
flannlib.flann_set_distance_type.argtypes = [
    c_int,
    c_int,
]

type_mappings = ( ('float', 'float32'),
                  ('double', 'float64'),
                  ('byte', 'uint8'),
                  ('int', 'int32') )


def define_functions(str):
    for type in type_mappings:
        eval(compile(str % {'C': type[0], 'numpy': type[1]}, '<string>', 'exec'))

flann.build_index = {}
define_functions(r"""
flannlib.flann_build_index_%(C)s.restype = FLANN_INDEX
flannlib.flann_build_index_%(C)s.argtypes = [
        ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'),  # dataset
        c_int,  # rows
        c_int,  # cols
        POINTER(c_float),  # speedup
        POINTER(FLANNParameters)  # flann_params
]
flann.build_index[%(numpy)s] = flannlib.flann_build_index_%(C)s
""")

flann.save_index = {}
define_functions(r"""
flannlib.flann_save_index_%(C)s.restype = None
flannlib.flann_save_index_%(C)s.argtypes = [
        FLANN_INDEX,  # index_id
        c_char_p #filename
]
flann.save_index[%(numpy)s] = flannlib.flann_save_index_%(C)s
""")

flann.load_index = {}
define_functions(r"""
flannlib.flann_load_index_%(C)s.restype = FLANN_INDEX
flannlib.flann_load_index_%(C)s.argtypes = [
        c_char_p,  #filename
        ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'),  # dataset
        c_int,  # rows
        c_int,  # cols
]
flann.load_index[%(numpy)s] = flannlib.flann_load_index_%(C)s
""")

flann.used_memory = {}
define_functions(r"""
flannlib.flann_used_memory_%(C)s.restype = c_int
flannlib.flann_used_memory_%(C)s.argtypes = [
        FLANN_INDEX, # index_id
]
flann.used_memory[%(numpy)s] = flannlib.flann_used_memory_%(C)s
""")

flann.add_points = {}
define_functions(r"""
flannlib.flann_add_points_%(C)s.restype = None
flannlib.flann_add_points_%(C)s.argtypes = [ 
        FLANN_INDEX, # index_id
        ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset
        c_int, # rows
        c_int, # cols
        c_float, # rebuild_threshhold
]
flann.add_points[%(numpy)s] = flannlib.flann_add_points_%(C)s
""")

flann.remove_point = {}
define_functions(r"""
flannlib.flann_remove_point_%(C)s.restype = None
flannlib.flann_remove_point_%(C)s.argtypes = [ 
        FLANN_INDEX, # index_id
        c_uint, # point_id
]
flann.remove_point[%(numpy)s] = flannlib.flann_remove_point_%(C)s
""")

flann.find_nearest_neighbors = {}
define_functions(r"""
flannlib.flann_find_nearest_neighbors_%(C)s.restype = c_int
flannlib.flann_find_nearest_neighbors_%(C)s.argtypes = [
        ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'),  # dataset
        c_int,  # rows
        c_int,  # cols
        ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'),  # testset
        c_int,  # tcount
        ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'),  # result
        ndpointer(float32, ndim=2, flags='aligned, c_contiguous, writeable'),  # dists
        c_int,  # nn
        POINTER(FLANNParameters)  # flann_params
]
flann.find_nearest_neighbors[%(numpy)s] = flannlib.flann_find_nearest_neighbors_%(C)s
""")

# fix definition for the 'double' case

flannlib.flann_find_nearest_neighbors_double.restype = c_int
flannlib.flann_find_nearest_neighbors_double.argtypes = [
    ndpointer(float64, ndim=2, flags='aligned, c_contiguous'),  # dataset
    c_int,  # rows
    c_int,  # cols
    ndpointer(float64, ndim=2, flags='aligned, c_contiguous'),  # testset
    c_int,  # tcount
    ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'),  # result
    ndpointer(float64, ndim=2, flags='aligned, c_contiguous, writeable'),  # dists
    c_int,  # nn
    POINTER(FLANNParameters)  # flann_params
]
flann.find_nearest_neighbors[float64] = flannlib.flann_find_nearest_neighbors_double


flann.find_nearest_neighbors_index = {}
define_functions(r"""
flannlib.flann_find_nearest_neighbors_index_%(C)s.restype = c_int
flannlib.flann_find_nearest_neighbors_index_%(C)s.argtypes = [
        FLANN_INDEX,  # index_id
        ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'),  # testset
        c_int,  # tcount
        ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'),  # result
        ndpointer(float32, ndim=2, flags='aligned, c_contiguous, writeable'),  # dists
        c_int,  # nn
        POINTER(FLANNParameters) # flann_params
]
flann.find_nearest_neighbors_index[%(numpy)s] = flannlib.flann_find_nearest_neighbors_index_%(C)s
""")

flannlib.flann_find_nearest_neighbors_index_double.restype = c_int
flannlib.flann_find_nearest_neighbors_index_double.argtypes = [
    FLANN_INDEX,  # index_id
    ndpointer(float64, ndim=2, flags='aligned, c_contiguous'),  # testset
    c_int,  # tcount
    ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'),  # result
    ndpointer(float64, ndim=2, flags='aligned, c_contiguous, writeable'),  # dists
    c_int,  # nn
    POINTER(FLANNParameters)  # flann_params
]
flann.find_nearest_neighbors_index[float64] = flannlib.flann_find_nearest_neighbors_index_double

flann.radius_search = {}
define_functions(r"""
flannlib.flann_radius_search_%(C)s.restype = c_int
flannlib.flann_radius_search_%(C)s.argtypes = [
        FLANN_INDEX,  # index_id
        ndpointer(%(numpy)s, ndim=1, flags='aligned, c_contiguous'),  # query
        ndpointer(int32, ndim=1, flags='aligned, c_contiguous, writeable'),  # indices
        ndpointer(float32, ndim=1, flags='aligned, c_contiguous, writeable'),  # dists
        c_int,  # max_nn
        c_float,  # radius
        POINTER(FLANNParameters) # flann_params
]
flann.radius_search[%(numpy)s] = flannlib.flann_radius_search_%(C)s
""")

flannlib.flann_radius_search_double.restype = c_int
flannlib.flann_radius_search_double.argtypes = [
    FLANN_INDEX,  # index_id
    ndpointer(float64, ndim=1, flags='aligned, c_contiguous'),  # query
    ndpointer(int32, ndim=1, flags='aligned, c_contiguous, writeable'),  # indices
    ndpointer(float64, ndim=1, flags='aligned, c_contiguous, writeable'),  # dists
    c_int,  # max_nn
    c_float,  # radius
    POINTER(FLANNParameters)  # flann_params
]
flann.radius_search[float64] = flannlib.flann_radius_search_double


flann.compute_cluster_centers = {}
define_functions(r"""
flannlib.flann_compute_cluster_centers_%(C)s.restype = c_int
flannlib.flann_compute_cluster_centers_%(C)s.argtypes = [
        ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'),  # dataset
        c_int,  # rows
        c_int,  # cols
        c_int,  # clusters
        ndpointer(float32, flags='aligned, c_contiguous, writeable'),  # result
        POINTER(FLANNParameters)  # flann_params
]
flann.compute_cluster_centers[%(numpy)s] = flannlib.flann_compute_cluster_centers_%(C)s
""")
# double is an exception
flannlib.flann_compute_cluster_centers_double.restype = c_int
flannlib.flann_compute_cluster_centers_double.argtypes = [
    ndpointer(float64, ndim=2, flags='aligned, c_contiguous'),  # dataset
    c_int,  # rows
    c_int,  # cols
    c_int,  # clusters
    ndpointer(float64, flags='aligned, c_contiguous, writeable'),  # result
    POINTER(FLANNParameters)  # flann_params
]
flann.compute_cluster_centers[float64] = flannlib.flann_compute_cluster_centers_double


flann.free_index = {}
define_functions(r"""
flannlib.flann_free_index_%(C)s.restype = None
flannlib.flann_free_index_%(C)s.argtypes = [
        FLANN_INDEX,  # index_id
        POINTER(FLANNParameters) # flann_params
]
flann.free_index[%(numpy)s] = flannlib.flann_free_index_%(C)s
""")


def ensure_2d_array(arr, flags, **kwargs):
    arr = require(arr, requirements=flags, **kwargs)
    if len(arr.shape) == 1:
        arr = arr.reshape(-1, arr.size)
    return arr
